#' \code{print.boot_glm}
#'
#' Print method for \code{boot_glm} fits: shows the matched call, the
#' bootstrap summary table, the number of observations and the number of
#' bootstrap resamples.
#' @param x an object of class boot_glm
#' @param digits number of significant digits used when formatting the
#'   summary table
#' @param ... other arguments to pass to print (currently unused)
#' @return the boot_glm argument, invisibly
#' @export
print.boot_glm <- function(x, digits = 2, ...)
{
  call_text <- paste(deparse(x$call), collapse = "\n")
  cat("\nCall:\n", call_text, "\n\n", sep = "")

  cat("Results:\n")
  summary_text <- format(x$summary, digits = digits)
  print.default(summary_text, print.gap = 2L, quote = FALSE)

  # Sample size and resample count, each set off by blank lines.
  cat("\nN = ", x$N, "\n\n",
    "\nNumber of Bootstrap Resamples = ", x$B, "\n\n",
    sep = "")

  invisible(x)
}

#' Bootstrapped GLM model
#'
#' Fit a GLM to \code{data}, then refit it on \code{B} bootstrap resamples
#' and summarise the resampled coefficients.
#'
#' Resampling is stratified by \code{cluster}: within every distinct value of
#' the \code{cluster} column, rows are drawn with replacement, so each
#' resample keeps the original number of rows per stratum. Rows whose
#' \code{cluster} value is NA never enter a resample.
#' @param formula see glm
#' @param data see glm; it is not modified (a copy is used internally)
#' @param cluster name of the column defining the resampling strata, or NULL
#'   to treat every row as its own stratum
#' @param B number of bootstrap resamples
#' @param return_boots logical, save resamples?
#' @param seed random seed; for a given seed and number of cores the
#'   resamples are reproducible
#' @param family see glm
#' @return an object of class boot_glm: a list with \code{call},
#'   \code{summary} (matrix with columns coef, serr, q025, q975, pval),
#'   \code{N}, \code{B}, \code{seed} and, if \code{return_boots}, the
#'   B x p matrix \code{boots} of resampled coefficients
#' @importFrom parallel detectCores mclapply
#' @export
boot_glm <- function(formula, data, cluster = "type", B = 1000,
  return_boots = TRUE, seed = sample.int(.Machine$integer.max, 1),
  family = gaussian)
{
  # mclapply() only gives forked workers reproducible RNG streams under
  # "L'Ecuyer-CMRG"; with other generators each child reseeds itself and
  # `seed` would not reproduce the resamples. Switch generator for this call
  # and restore the caller's generator on exit.
  old_kind <- RNGkind("L'Ecuyer-CMRG")
  on.exit(RNGkind(old_kind[1L]), add = TRUE)
  set.seed(seed)
  cl <- match.call()
  model <- glm(formula, data, family = family)
  # Work on a copy: setDT() would convert the caller's object in place.
  data <- as.data.frame(data)
  if (!is.null(cluster)) {
    cluster_name <- cluster
    cluster <- data[[cluster_name]]
    if (is.null(cluster)) {
      stop("cluster column '", cluster_name, "' not found in data",
        call. = FALSE)
    }
  } else {
    cluster <- seq_len(nrow(data))
  }
  # Row indices of each stratum, in order of first appearance.
  idx <- lapply(unique(cluster), function(x) which(cluster == x))
  # sample(x) with a single number x >= 1 draws from 1:x, so index the rows
  # explicitly to keep one-row strata intact.
  resample <- function(rows) rows[sample.int(length(rows), replace = TRUE)]
  boot <- function(fakearg)
  {
    ok <- unlist(lapply(idx, resample), use.names = FALSE)
    boot_model <- glm(formula, data = data[ok, , drop = FALSE],
      family = family)
    coef(boot_model)
  }
  # Forking is unavailable on Windows; detectCores() may return NA.
  n_cores <- if (.Platform$OS.type == "windows") 1L else
    parallel::detectCores()
  if (is.na(n_cores) || n_cores < 1L) {
    n_cores <- 1L
  }
  fits <- parallel::mclapply(seq_len(B), boot, mc.cores = n_cores)
  # Worker errors come back as "try-error" objects rather than being raised.
  failed <- !vapply(fits, is.numeric, logical(1))
  if (any(failed)) {
    stop(sum(failed), " of ", B, " bootstrap refits failed", call. = FALSE)
  }
  boots <- do.call(rbind, fits)
  summ <- cbind(
    coef(model),
    apply(boots, 2, sd),
    apply(boots, 2, quantile, prob = 0.025),
    apply(boots, 2, quantile, prob = 0.975),
    # two-sided bootstrap p-value: twice the smaller tail mass around zero
    apply(boots, 2, function(x) 2 * min(mean(x < 0), mean(x > 0))))
  colnames(summ) <- c("coef", "serr", "q025", "q975", "pval")
  ret <- list(call = cl, summary = summ, N = nrow(data), B = B, seed = seed)
  if (return_boots) {
    ret$boots <- boots
  }
  class(ret) <- "boot_glm"
  ret
}

#' Generate baseline experimental data for cheapchat
#'
#' Builds the full design grid (one row per treatment x session x subject x
#' module x round) and randomly assigns roles, pairing groups and targets.
#'
#' A treatment-session pair uniquely identifies a physical lab session.
#' A treatment-session-id triple uniquely identifies a physical subject.
#' A treatment-session-id-module-round tuple uniquely identifies a row, i.e.
#' one subject in one round.
#' Within each treatment-session, exactly half of the ids are randomly
#' assigned to be senders and the rest receivers; roles are fixed for the
#' whole session.
#' Within each treatment-session-module-round-role, groups are a random
#' permutation of the integers 1 to n_subjects_per_session / 2, so each
#' group in a round contains one sender and one receiver (a pairing).
#' Targets are drawn once per treatment-session-module-round-group, so both
#' members of a pairing share a target; they are whole numbers drawn
#' uniformly from -100 to 100.
#' @param n_treatments integer, number of treatments
#' @param n_sessions_per_treatment integer, number of sessions per treatment
#' @param n_subjects_per_session integer, number of subjects per session;
#'   must be even
#' @param n_modules_per_session integer, number of modules per session
#' @param n_rounds_per_module integer, number of rounds per module
#' @return data.table of baseline experimental data with columns treatment,
#'   session, id, module, round, role, group and target
#' @import data.table
#' @export
simulate_baseline <- function(n_treatments, n_sessions_per_treatment,
  n_subjects_per_session, n_modules_per_session, n_rounds_per_module)
{
  # Roles split each session's subjects into two equal halves.
  stopifnot(n_subjects_per_session %% 2 == 0)
  data <- CJ(
    treatment = seq_len(n_treatments),
    session = seq_len(n_sessions_per_treatment),
    id = seq_len(n_subjects_per_session),
    module = seq_len(n_modules_per_session),
    round = seq_len(n_rounds_per_module))
  # One draw per treatment-session, so an id keeps its role all session.
  data[, role := ifelse(
    id %in% sample.int(n_subjects_per_session, n_subjects_per_session / 2),
    "sender", "receiver"),
    by = .(treatment, session)]
  # Each (treatment, session, module, round, role) cell has
  # n_subjects_per_session / 2 rows, so this permutes the group labels and
  # every group gets exactly one sender and one receiver.
  data[, group := sample.int(n_subjects_per_session / 2),
    by = .(treatment, session, module, round, role)]
  # One target per pairing, shared by its sender and receiver.
  data[, target := sample(seq(-100, 100, 1), 1),
    by = .(treatment, session, module, round, group)]
  # `[]` so the result prints normally after the in-place := updates.
  data[]
}

#' arctanh
#'
#' Inverse hyperbolic tangent, 1 / 2 * log((1 + x) / (1 - x)).
#'
#' Delegates to base R's \code{atanh}, which computes the same function but
#' stays accurate for x close to 0, where the direct ratio rounds to 1 and
#' the formula returns 0.
#' @param x real number (vectorised over numeric input)
#' @return arctanh of x; Inf / -Inf at 1 / -1 and NaN (with a warning) for
#'   |x| > 1
#' @export
arctanh <- function(x)
{
  atanh(x)
}



# Build a LaTeX regression table (threeparttable) from a list of fitted
# mixed models and write it to stdout with cat(); the function returns NULL.
#
# The accessors used (fixef, vcov, ngrps, VarCorr, sigma, nobs and the
# `stan_summary` element) look like rstanarm stan_lmer/stan_glmer fits --
# NOTE(review): confirm the expected model class.
#
# Arguments:
#   model_list  list of fitted models; one table column per model
#   model_names column headers, joined with "&"
#   coef_order  only its LENGTH is used: it sets how many coefficient rows are
#               printed, and rows follow the union of fixef() names in
#               first-appearance order. NOTE(review): likely meant to reorder
#               rows, as group_order does below -- confirm.
#   group_order indices used for BOTH the "$n$" rows (n_list, aligned to
#               grouping-factor names) and the SD rows (SD_list, aligned to
#               VarCorr names + residual). NOTE(review): these two orderings
#               are not guaranteed to agree.
#   coef_names  row labels, one per distinct fixed-effect name (checked)
#   label       LaTeX \label
#   group_names labels for the non-residual SD rows; "Residual" is appended
#   title       table caption
#   scalebox    \scalebox factor
#   notes       text placed in the tablenotes
make_reg_table <- function(model_list, model_names, coef_order, group_order,
  coef_names, label, group_names, title, scalebox, notes)
{
  b_list <- lapply(model_list, fixef)

  varnames_list <- lapply(b_list, names)
  # Standard errors from the diagonal of each model's vcov().
  se_list <- lapply(model_list, function(x) diag(vcov(x)) ^ .5)
  q025_list <- mapply(function(x, y) x$stan_summary[y, "2.5%"],
    model_list, varnames_list, SIMPLIFY = FALSE)
  q975_list <- mapply(function(x, y) x$stan_summary[y, "97.5%"],
    model_list, varnames_list, SIMPLIFY = FALSE)
  # Star a coefficient when its 2.5% and 97.5% quantiles have the same sign.
  star_list <- mapply(function(q025, q975)
    ifelse(sign(q025) == sign(q975), "^*", "  "), q025_list, q975_list,
    SIMPLIFY = FALSE)
  B_list <- mapply(function(x, y) paste0(sprintf("%.02f", x), y),
    b_list, star_list, SIMPLIFY = FALSE)
  S_list <- lapply(se_list, function(x) paste0("(", sprintf("%.02f", x), ")"))

  n_list <- lapply(model_list, function(x) ngrps(x))
  # NOTE(review): groupnames_list is never used below.
  groupnames_list <- lapply(n_list, names)
  # Group SDs: sqrt of every unlisted VarCorr entry, then sigma() with an
  # empty name. NOTE(review): with random slopes unlist() also yields
  # covariances, whose square root can be NaN -- confirm models are
  # intercept-only.
  sd_list <- lapply(model_list,
    function(x) c(unlist(VarCorr(x)) ^ .5, sigma(x)))
  SD_list <- lapply(sd_list, function(x) sprintf("%.02f", x))

  # Align fixed-effect cells to the union of coefficient names; models that
  # lack a coefficient get "" in that row.
  cov_names <- unique(unlist(lapply(model_list, function(x) names(fixef(x)))))
  k1 <- length(cov_names)
  stopifnot(length(coef_names) == length(cov_names))
  map_list <- lapply(b_list, function(x) match(names(x), cov_names))
  sync <- function(x, y) {
    out <- rep("", k1)
    out[y] <- x
    out
  }
  B_list <- mapply(sync, B_list, map_list, SIMPLIFY = FALSE)
  S_list <- mapply(sync, S_list, map_list, SIMPLIFY = FALSE)

  # Same alignment for group counts, over the union of grouping factors.
  obs_names <- unique(unlist(lapply(n_list, function(x) names(x))))
  k2 <- length(obs_names)
  map_list <- lapply(n_list, function(x) match(names(x), obs_names))
  sync <- function(x, y) {
    out <- rep("", k2)
    out[y] <- x
    out
  }
  n_list <- mapply(sync, n_list, map_list, SIMPLIFY = FALSE)

  # Same alignment for SDs; the unnamed residual ("") is forced last.
  grp_names <- unique(unlist(lapply(sd_list, function(x) names(x))))
  grp_names <- c(setdiff(grp_names, ""), "")
  k3 <- length(grp_names)
  stopifnot(length(grp_names) == length(group_names) + 1)
  map_list <- lapply(sd_list, function(x) match(names(x), grp_names))
  sync <- function(x, y) {
    out <- rep("", k3)
    out[y] <- x
    out
  }
  # NOTE(review): the synced sd_list and `m` below are not used afterwards.
  sd_list <- mapply(sync, sd_list, map_list, SIMPLIFY = FALSE)
  SD_list <- mapply(sync, SD_list, map_list, SIMPLIFY = FALSE)
  group_names <- c(group_names, "Residual")
  model_names <- paste0(model_names, collapse = "&")
  n <- length(b_list)
  m <- length(group_names)
  # Column spec: one left-aligned label column, one centred column per model.
  just <- paste0("l", paste0(rep("c", n), collapse = ""))

  # coef_names_lengths <- sapply(coef_names, length)
  # coef_padding <- max(coef_names_lengths) + 1 - coef_names_lengths
  # coef_names <- sapply(seq_along(coef_names), function(i)
  #   paste0("    ", paste0(coef_names[i], rep(" ", coef_padding[i]))))


  # Alternate elements: v1[1], v2[1], v1[2], v2[2], ...
  # NOTE(review): a zero-length v1 never reaches the length-1 base case and
  # recurses until R's nesting limit errors.
  interleave <- function(v1, v2)
  {
    if (length(v1) == 1) {
      c(v1, v2)
    } else {
      c(interleave(head(v1, -1), head(v2, -1)), tail(v1, 1), tail(v2, 1))
    }
  }
  # Coefficient rows (estimate + star) and their "(se)" rows, interleaved so
  # each estimate is followed by its standard error.
  part1a <- #paste0(
    sapply(seq_along(coef_order), function(i)
      paste0(coef_names[i], "&$",
        paste0(sapply(B_list, function(x) x[i]), collapse = "$&$"), "$\\\\",
        collapse = ""))#, collapse = "\n")

  part1b <- #paste0(
    sapply(seq_along(coef_order), function(i)
      paste0("&$",
        paste0(sapply(S_list, function(x) x[i]), collapse = "$&$"), "$\\\\",
        collapse = ""))#, collapse = "\n")

  part1 <- paste0(interleave(part1a, part1b), collapse = "\n")
  # "$n$ <group>" rows: number of levels of each grouping factor.
  part2 <- paste0(sapply(seq_len(k2), function(i)
    paste0("$n$ ", group_names[group_order[i]], "&$",
      paste0(sapply(n_list, function(x) x[group_order[i]]), collapse = "$&$"),
      "$\\\\", collapse = "")), collapse = "\n")
  # Group SD rows, including the residual SD.
  part3 <- paste0(sapply(seq_len(k3), function(i)
    paste0(group_names[group_order[i]], "&$",
      paste0(sapply(SD_list, function(x) x[group_order[i]]), collapse = "$&$"),
      "$\\\\", collapse = "")), collapse = "\n")
  # Assemble the LaTeX; whitespace inside this literal is emitted verbatim.
  tab <- paste0("
    \\begin{table}[t] \\footnotesize
    \\begin{center}
    \\caption{", title, "}
    \\label{", label, "} \\vspace{1em}
    \\scalebox{", scalebox, "}{
    \\begin{threeparttable}
    \\begin{tabular}{", just, "}  \\vspace{.1em}
    & ", model_names, " \\\\
    \\hline
    ", part1, "
    \\hline
    $n$ Observations&$",
    paste0(sapply(model_list, nobs), collapse = "$&$"), "$\\\\
    ", part2, "
    \\hline
    \\emph{Error terms} & \\multicolumn{", n, "}{c}{\\underline{Group SD}}\\\\
    ", part3, "
    \\hline \\hline
    \\end{tabular}
    \\begin{tablenotes}
    \\item{\\footnotesize
    ", notes, "
    }
    \\end{tablenotes}
    \\end{threeparttable}
    }
    \\end{center}
    \\end{table}
    ")
  cat(tab)#, file = paste0("results/", label, ".tex"))
  # invisible(tab)
}



# Average effect of switching one regressor from 0 to 1.
#
# Sets column `var` of the model frame to 1 and then to 0 for every row,
# predicts from `m` under both, and summarises the per-row difference.
#
#   m         fitted model supporting model.frame() and
#             predict(m, newdata = ...)
#   var       name of the column of model.frame(m) to toggle
#   prettyvar label stored in the `variable` column of the result
#
# Returns a one-row data.table with columns:
#   variable, estimate (mean change),
#   se (sd(change) / sqrt(n): spread of the per-row changes, not a
#       model-based standard error),
#   q25, q75 (quartiles of the changes),
#   q025, q975 (estimate +/- qnorm(.975) * se).
calc_effect <- function(m, var, prettyvar) {
  frame <- model.frame(m)
  # Assigning to a missing column would add an unused column and silently
  # report a zero effect (e.g. when the formula wraps `var` in a transform).
  if (!var %in% names(frame)) {
    stop("'", var, "' is not a column of model.frame(m)", call. = FALSE)
  }
  X1 <- frame
  X0 <- frame
  X1[, var] <- 1
  X0[, var] <- 0
  # Pass newdata by name: some predict() methods take `...` as their second
  # formal, so a positional data frame would be ignored.
  change <- predict(m, newdata = X1) - predict(m, newdata = X0)
  out <- data.table(variable = prettyvar,
    estimate = mean(change),
    se = sd(change) / sqrt(length(change)),
    q25 = quantile(change, .25),
    q75 = quantile(change, .75))
  out[, q025 := estimate + qnorm(.025) * se]
  out[, q975 := estimate + qnorm(.975) * se]
  out
}